# -*- coding: utf-8 -*-  
'''
初始化cpt2.0

@author: luoyi
Created on 2021年4月7日
'''
import utils.conf as conf
import data.dataset_bert as ds_bert


def _write_split(split_name, f_path, s_path, count, tfrecord_dir):
    """Build BERT pre-training samples for one dataset split and save them as TFRecords.

    The sampling / rewrite parameters (negative-sample probability and the
    rewrite mask/original/random settings) and the per-file record limit are
    read from ``conf.BERT`` and are identical for every split.  Only the
    split-specific inputs and output directory are passed in.

    :param split_name: split label used only in the progress messages (e.g. 'train', 'val')
    :param f_path: forwarded as ``f_path`` to ``ds_bert.sample_generator``
    :param s_path: forwarded as ``s_path`` to ``ds_bert.sample_generator``
    :param count: forwarded as ``count`` to ``ds_bert.sample_generator``
    :param tfrecord_dir: output directory forwarded to ``ds_bert.save_tfrecord``
    """
    print('开始写入%s数据集...' % split_name)
    # NOTE(review): presumably sample_generator returns a lazy iterator that
    # save_tfrecord consumes while writing — confirm in data.dataset_bert.
    samples = ds_bert.sample_generator(f_path=f_path,
                                       s_path=s_path,
                                       count=count,
                                       neg_prob=conf.BERT.get_neg_prob(),
                                       rewrite_prob=conf.BERT.get_rewrite_prob(),
                                       rewrite_mask=conf.BERT.get_rewrite_mask(),
                                       rewrite_original=conf.BERT.get_rewrite_original(),
                                       rewrite_random=conf.BERT.get_rewrite_random())
    ds_bert.save_tfrecord(samples,
                          tfrecord_dir=tfrecord_dir,
                          tfrecord_limit=conf.BERT.get_pre_training_tfrecord_limit())
    print('写入%s数据集成功.' % split_name)


#    Write the training split.
_write_split('train',
             f_path=conf.DATASET.get_in_train(),
             s_path=conf.DATASET.get_label_train(),
             count=conf.DATASET.get_count_train(),
             tfrecord_dir=conf.BERT.get_pre_training_tfrecord_train())


#    Write the validation split.
_write_split('val',
             f_path=conf.DATASET.get_in_val(),
             s_path=conf.DATASET.get_label_val(),
             count=conf.DATASET.get_count_val(),
             tfrecord_dir=conf.BERT.get_pre_training_tfrecord_val())
